{
  "cells": [
    {
      "metadata": {
        "id": "d0ABTeSA_QZh"
      },
      "cell_type": "markdown",
      "source": [
        "# Shardy Guide for JAX Users\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][jax-sdy-tutorial]\n",
        "\n",
        "[jax-sdy-tutorial]: https://colab.research.google.com/github/openxla/shardy/blob/main/docs/getting_started_jax.ipynb\n"
      ]
    },
    {
      "metadata": {
        "id": "c8HlD6EELrYt"
      },
      "cell_type": "markdown",
      "source": [
        "Shardy is a new propagation system being introduced into the XLA stack, and below we want to introduce any JAX users to:\n",
        "\n",
        "1. What has changed in JAX\n",
        "2. Why Shardy?\n",
        "3. Future plans\n",
        "\n",
        "This is meant for JAX users who use `jax.jit` for running training/inference models across more than 1 GPU or TPU (batch parallelism, megatron, ZeRO, etc). They would be using things like `PartitionSpec`s and `NamedSharding`s."
      ]
    },
    {
      "metadata": {
        "id": "6l9mbfCXLrYu"
      },
      "cell_type": "markdown",
      "source": [
        "## 1. What has changed in JAX?"
      ]
    },
    {
      "metadata": {
        "id": "xRyBxdUsLrYv"
      },
      "cell_type": "markdown",
      "source": [
        "### State of JAX before: GSPMD"
      ]
    },
    {
      "metadata": {
        "id": "RhCIp7cXLrYw"
      },
      "cell_type": "markdown",
      "source": [
        "Prior to Shardy, JAX users who partitioned their models across multiple devices used [GSPMD](https://arxiv.org/abs/2105.04663) behind the scenes.\n",
        "\n",
        "GSPMD is the propagation+partitioning system that lives in the middle of the XLA pipeline. It operates on HLO - the IR that comes after StableHLO (the program you get after running `jax.jit.lower`).\n",
        "\n",
        "JAX doesn't run GSPMD directly, but encodes instructions into the StableHLO IR for GSPMD to read later on.\n",
        "\n",
        "But before we go any further, let's introduce our working example.\n",
        "\n",
        "Make sure you have installed `jax\u003e=0.4.35`."
      ]
    },
    {
      "metadata": {
        "id": "anTB6ODVHsGg"
      },
      "cell_type": "code",
      "source": [
        "!pip install jax==0.4.35"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "4flIPFVmAYNw"
      },
      "cell_type": "code",
      "source": [
        "#@title Imports and utilities\n",
        "\n",
        "import os\n",
        "# make sure our code runs on 8 devices\n",
        "os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n",
        "\n",
        "import jax\n",
        "import numpy as np\n",
        "from jax import numpy as jnp\n",
        "from jax.sharding import Mesh, NamedSharding, PartitionSpec\n",
        "from jax.experimental.shard_map import shard_map"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "fl3cXGo9CRMo"
      },
      "cell_type": "markdown",
      "source": [
        "First, let's create our mesh."
      ]
    },
    {
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9NTq9HXCE5OQ",
        "outputId": "4ffb4bc7-b1e8-4ea7-d706-2758267625f1"
      },
      "cell_type": "code",
      "source": [
        "mesh = Mesh(\n",
        "    np.reshape(np.array(jax.devices()), (4, 2)),\n",
        "    ('data', 'model'))\n",
        "\n",
        "print(mesh.shape)"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "OrderedDict([('data', 4), ('model', 2)])\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "YjgahNBpNu19"
      },
      "cell_type": "markdown",
      "source": [
        "### In/Out shardings\n",
        "\n",
        "Let's look at what changed the most: how sharding attributes are encoded in the JAX program for the compiler to read.\n",
        "\n",
        "Let's look at it through an example. It's going to be an MLP-like model consisting of no bias tensors, and 2 layers (two matmuls)."
      ]
    },
    {
      "metadata": {
        "id": "A--HECIHAe5E"
      },
      "cell_type": "code",
      "source": [
        "def predict(x, w1, w2):\n",
        "  x = jnp.tanh(x)\n",
        "  z1 = jnp.einsum('ij,jk-\u003eik', x, w1)\n",
        "  z2 = jnp.einsum('ij,jk-\u003eik', z1, w2)\n",
        "  return jnp.sin(z2)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "kLR4IUmXGb7b"
      },
      "cell_type": "markdown",
      "source": [
        "What we will want to do here sharding wise is:\n",
        "\n",
        "1. `data` parallelism on x\n",
        "2. `tensor` parallelism on `w1` and `w2` through the [megatron](https://arxiv.org/abs/1909.08053) sharding strategy.\n",
        "\n",
        "Now let's prepare the model for GSPMD sharding. Note that we will explicitly shard `w1`, but let GSPMD propagation shard `w2`."
      ]
    },
    {
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nZaM4mVKF1UQ",
        "outputId": "c9f7b042-1cc1-4743-a475-8f325ffd2eac"
      },
      "cell_type": "code",
      "source": [
        "def run_in_out_shardings():\n",
        "  samples = jax.ShapeDtypeStruct((16, 128), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec('data', None)))\n",
        "  samples_sharding = NamedSharding(mesh, PartitionSpec('data', None))\n",
        "  w1 = jax.ShapeDtypeStruct((128, 256), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec(None, 'model')))\n",
        "  w1_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))\n",
        "  w2 = jax.ShapeDtypeStruct((256, 10), jnp.float32)\n",
        "  w2_sharding = None\n",
        "\n",
        "  print(jax.jit(predict, in_shardings=(samples_sharding, w1_sharding, w2_sharding)).lower(samples, w1, w2).as_text())\n",
        "\n",
        "run_in_out_shardings()"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  func.func public @main(%arg0: tensor\u003c16x128xf32\u003e {mhlo.sharding = \"{devices=[4,1,2]\u003c=[8] last_tile_dim_replicate}\"}, %arg1: tensor\u003c128x256xf32\u003e {mhlo.sharding = \"{devices=[1,2,4]\u003c=[4,2]T(1,0) last_tile_dim_replicate}\"}, %arg2: tensor\u003c256x10xf32\u003e) -\u003e (tensor\u003c16x10xf32\u003e {jax.result_info = \"\"}) {\n",
            "    %0 = stablehlo.tanh %arg0 : tensor\u003c16x128xf32\u003e\n",
            "    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor\u003c16x128xf32\u003e, tensor\u003c128x256xf32\u003e) -\u003e tensor\u003c16x256xf32\u003e\n",
            "    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor\u003c16x256xf32\u003e, tensor\u003c256x10xf32\u003e) -\u003e tensor\u003c16x10xf32\u003e\n",
            "    %3 = stablehlo.sine %2 : tensor\u003c16x10xf32\u003e\n",
            "    return %3 : tensor\u003c16x10xf32\u003e\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "B5oXRrVoPh8P"
      },
      "cell_type": "markdown",
      "source": [
        "GSPMD's sharding annotations look like the following:\n",
        "\n",
        "| JAX sharding    | GSPMD sharding |\n",
        "| -------- | ------- |\n",
        "| `NamedSharding(mesh, PartitionSpec('data', None))`  | `{devices=[4,1,2]\u003c=[8] last_tile_dim_replicate}`    |\n",
        "| `NamedSharding(mesh, PartitionSpec(None, 'model'))` | `{devices=[1,2,4]\u003c=[4,2]T(1,0) last_tile_dim_replicate}`     |\n",
        "| `None`    | nothing    |\n",
        "\n",
        "`None` is no sharding as expected since GSPMD will populate this during sharding propagation.\n",
        "\n",
        "Notice how all the axis names go away? While there is a 1:1 correspondence between `NamedSharding` and GSPMD sharding, as a reader, it can be difficult to read. It is only more difficult once you introduce various axis names.\n",
        "\n",
        "Let's look at Shardy for comparison. To enable Shardy in JAX, simply enable the flag:"
      ]
    },
    {
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_k7sNUIOQsxm",
        "outputId": "21e13c60-416e-43e5-bdbd-2942b8b3b0a6"
      },
      "cell_type": "code",
      "source": [
        "jax.config.update(\"jax_use_shardy_partitioner\", True)\n",
        "run_in_out_shardings()\n",
        "jax.config.update(\"jax_use_shardy_partitioner\", False)"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  sdy.mesh @mesh = \u003c[\"data\"=4, \"model\"=2]\u003e\n",
            "  func.func public @main(%arg0: tensor\u003c16x128xf32\u003e {sdy.sharding = #sdy.sharding\u003c@mesh, [{\"data\"}, {}]\u003e}, %arg1: tensor\u003c128x256xf32\u003e {sdy.sharding = #sdy.sharding\u003c@mesh, [{}, {\"model\"}]\u003e}, %arg2: tensor\u003c256x10xf32\u003e) -\u003e (tensor\u003c16x10xf32\u003e {jax.result_info = \"\"}) {\n",
            "    %0 = stablehlo.tanh %arg0 : tensor\u003c16x128xf32\u003e\n",
            "    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor\u003c16x128xf32\u003e, tensor\u003c128x256xf32\u003e) -\u003e tensor\u003c16x256xf32\u003e\n",
            "    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor\u003c16x256xf32\u003e, tensor\u003c256x10xf32\u003e) -\u003e tensor\u003c16x10xf32\u003e\n",
            "    %3 = stablehlo.sine %2 : tensor\u003c16x10xf32\u003e\n",
            "    return %3 : tensor\u003c16x10xf32\u003e\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "YAhl-fiERfOQ"
      },
      "cell_type": "markdown",
      "source": [
        "Now we have\n",
        "\n",
        "| JAX sharding    | Shardy sharding |\n",
        "| -------- | ------- |\n",
        "| `NamedSharding(mesh, PartitionSpec('data', None))`  | `#sdy.sharding\u003c@mesh, [{\"data\"}, {}]\u003e`    |\n",
        "| `NamedSharding(mesh, PartitionSpec(None, 'model'))` | `#sdy.sharding\u003c@mesh, [{}, {\"model\"}]\u003e`     |\n",
        "| `None`    | nothing    |\n",
        "\n",
        "Shardy's representation is a lot closer to what JAX `NamedSharding`s are like. So when looking at a file dump of your program after propagation, it will be a lot easier to understand what is going on since the correspondence is a lot closer to JAX.\n",
        "\n",
        "Note that instead of the total devices/axes living on the sharding, they live on a top level `@mesh` value."
      ]
    },
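    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "The cells in this guide switch Shardy on and off by hand around each call. As a convenience, that toggle can be wrapped in a context manager. This is a minimal sketch: `shardy_enabled` is a hypothetical helper name, not a JAX API, and it always resets the flag to `False` on exit, mirroring what the cells in this guide do."
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "import contextlib\n",
        "\n",
        "@contextlib.contextmanager\n",
        "def shardy_enabled():\n",
        "  # Hypothetical helper: turn Shardy on for the duration of the `with` block.\n",
        "  jax.config.update(\"jax_use_shardy_partitioner\", True)\n",
        "  try:\n",
        "    yield\n",
        "  finally:\n",
        "    # Always switch back, even if the body raises.\n",
        "    jax.config.update(\"jax_use_shardy_partitioner\", False)\n",
        "\n",
        "# Usage: lower the MLP example with Shardy enabled.\n",
        "with shardy_enabled():\n",
        "  run_in_out_shardings()"
      ],
      "outputs": [],
      "execution_count": null
    },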
    {
      "metadata": {
        "id": "rg8XG7y9SEVJ"
      },
      "cell_type": "markdown",
      "source": [
        "### `jax.lax.with_sharding_constraint`\n",
        "\n",
        "GSPMD currently lowers it to a custom call:"
      ]
    },
    {
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1O_PB0CDa-OV",
        "outputId": "733ba26e-0bcd-4454-a8aa-381433b7d66a"
      },
      "cell_type": "code",
      "source": [
        "def run_with_sharding_constraint():\n",
        "  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)\n",
        "\n",
        "  def f(x):\n",
        "    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec('data', PartitionSpec.UNCONSTRAINED)))\n",
        "\n",
        "  print(jax.jit(f).lower(x).as_text())\n",
        "\n",
        "run_with_sharding_constraint()"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  func.func public @main(%arg0: tensor\u003c32x64xf32\u003e) -\u003e (tensor\u003c32x64xf32\u003e {jax.result_info = \"\"}) {\n",
            "    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = \"unspecified_dims=[1]\", mhlo.sharding = \"{devices=[4,1,2]\u003c=[8] last_tile_dim_replicate}\"} : (tensor\u003c32x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "    return %0 : tensor\u003c32x64xf32\u003e\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "d5pLLzvsbf6O"
      },
      "cell_type": "markdown",
      "source": [
        "But under Shardy it's an explicit op:"
      ]
    },
    {
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "haybWp_0bxI8",
        "outputId": "4a450ea7-5f35-4c3d-a06a-f254c89ce9b1"
      },
      "cell_type": "code",
      "source": [
        "jax.config.update(\"jax_use_shardy_partitioner\", True)\n",
        "run_with_sharding_constraint()\n",
        "jax.config.update(\"jax_use_shardy_partitioner\", False)"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  sdy.mesh @mesh = \u003c[\"data\"=4, \"model\"=2]\u003e\n",
            "  func.func public @main(%arg0: tensor\u003c32x64xf32\u003e) -\u003e (tensor\u003c32x64xf32\u003e {jax.result_info = \"\"}) {\n",
            "    %0 = sdy.sharding_constraint %arg0 \u003c@mesh, [{\"data\"}, {?}]\u003e : tensor\u003c32x64xf32\u003e\n",
            "    return %0 : tensor\u003c32x64xf32\u003e\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "kXxwCyk6KbT-"
      },
      "cell_type": "markdown",
      "source": [
        "Note that `UNCONSTRAINED` under GSPMD has the custom call have an op attribute `backend_config = \"unspecified_dims=[1]\"`. But under Shardy, it makes dim 1 be `{?}`. In Shardy, dimension shardings without a `?` are closed, meaning that dimension can't be further sharded, but when it has a trailing `?`, it can be further sharded. Refer to [Sharding representation](https://github.com/openxla/shardy/tree/main/docs/sharding_representation.md) for more info on the sharding representation."
      ]
    },
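    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "To see closed and open dimensions side by side, the sketch below lowers the same function twice under Shardy: once with `None` for dim 1 and once with `PartitionSpec.UNCONSTRAINED`. It reuses the `mesh` and imports from the earlier cells; `lower_constraint` is a hypothetical helper name used only for illustration. Going by the description above, we would expect the first module to show a closed `{}` for dim 1 and the second an open `{?}`."
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "def lower_constraint(spec):\n",
        "  # Lower a function whose only op is a sharding constraint with `spec`.\n",
        "  def f(x):\n",
        "    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))\n",
        "\n",
        "  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)\n",
        "  return jax.jit(f).lower(x).as_text()\n",
        "\n",
        "jax.config.update(\"jax_use_shardy_partitioner\", True)\n",
        "# Dim 1 given as None.\n",
        "print(lower_constraint(PartitionSpec('data', None)))\n",
        "# Dim 1 left unconstrained.\n",
        "print(lower_constraint(PartitionSpec('data', PartitionSpec.UNCONSTRAINED)))\n",
        "jax.config.update(\"jax_use_shardy_partitioner\", False)"
      ],
      "outputs": [],
      "execution_count": null
    },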
    {
      "metadata": {
        "id": "lpHQaPcob5i5"
      },
      "cell_type": "markdown",
      "source": [
        "### `jax.experimental.shard_map`\n",
        "\n",
        "Under GSPMD this is a few different custom calls with various `shard_map` specific attributes on the GSPMD sharding. Let's look where the `model` axis is `auto`, meaning it's free to be used inside the body of the shard_map by sharding constraints."
      ]
    },
    {
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SWWjHMLmcNcW",
        "outputId": "00cf1e63-80ed-4c7c-fa44-844ac8739ccb"
      },
      "cell_type": "code",
      "source": [
        "def run_shard_map():\n",
        "  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)\n",
        "\n",
        "  def body(x):\n",
        "    return jax.lax.all_gather(x, 'data', tiled=True)\n",
        "\n",
        "  shmaped_f = shard_map(\n",
        "        body,\n",
        "        mesh=mesh,\n",
        "        in_specs=(jax.sharding.PartitionSpec('data',),),\n",
        "        out_specs=jax.sharding.PartitionSpec(),\n",
        "        check_rep=False)\n",
        "\n",
        "  print(jax.jit(shmaped_f).lower(x).as_text())\n",
        "\n",
        "print(run_shard_map())"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  func.func public @main(%arg0: tensor\u003c32x64xf32\u003e) -\u003e (tensor\u003c32x64xf32\u003e {jax.result_info = \"\"}) {\n",
            "    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = \"\", mhlo.sharding = \"{devices=[4,1,2]\u003c=[8] last_tile_dim_replicate}\"} : (tensor\u003c32x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "    %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = \"\", mhlo.sharding = \"{manual}\"} : (tensor\u003c32x64xf32\u003e) -\u003e tensor\u003c8x64xf32\u003e\n",
            "    %2 = call @shmap_body(%1) : (tensor\u003c8x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "    %3 = stablehlo.custom_call @Sharding(%2) {backend_config = \"\", mhlo.sharding = \"{manual}\"} : (tensor\u003c32x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "    %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = \"\", mhlo.sharding = \"{replicated}\"} : (tensor\u003c32x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "    return %4 : tensor\u003c32x64xf32\u003e\n",
            "  }\n",
            "  func.func private @shmap_body(%arg0: tensor\u003c8x64xf32\u003e) -\u003e (tensor\u003c32x64xf32\u003e {jax.result_info = \"[None, None]\"}) {\n",
            "    %0 = \"stablehlo.all_gather\"(%arg0) \u003c{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle\u003chandle = 1, type = 1\u003e, replica_groups = dense\u003c[[0, 2, 4, 6], [1, 3, 5, 7]]\u003e : tensor\u003c2x4xi64\u003e, use_global_device_ids}\u003e : (tensor\u003c8x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "    return %0 : tensor\u003c32x64xf32\u003e\n",
            "  }\n",
            "}\n",
            "\n",
            "None\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "yug6aSqXc_Sv"
      },
      "cell_type": "markdown",
      "source": [
        "With the custom calls and GSPMD sharding, it's getting pretty confusing. Let's look at what Shardy gives:"
      ]
    },
    {
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1hCGbWnbdHa-",
        "outputId": "bd6351f9-f430-48f1-c37f-30da033925ad"
      },
      "cell_type": "code",
      "source": [
        "jax.config.update(\"jax_use_shardy_partitioner\", True)\n",
        "run_shard_map()\n",
        "jax.config.update(\"jax_use_shardy_partitioner\", False)"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  sdy.mesh @mesh = \u003c[\"data\"=4, \"model\"=2]\u003e\n",
            "  func.func public @main(%arg0: tensor\u003c32x64xf32\u003e) -\u003e (tensor\u003c32x64xf32\u003e {jax.result_info = \"\"}) {\n",
            "    %0 = sdy.manual_computation(%arg0) in_shardings=[\u003c@mesh, [{\"data\"}, {}]\u003e] out_shardings=[\u003c@mesh, [{}, {}]\u003e] manual_axes={\"data\", \"model\"} (%arg1: tensor\u003c8x64xf32\u003e) {\n",
            "      %1 = \"stablehlo.all_gather\"(%arg1) \u003c{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle\u003chandle = 1, type = 1\u003e, replica_groups = dense\u003c[[0, 2, 4, 6], [1, 3, 5, 7]]\u003e : tensor\u003c2x4xi64\u003e, use_global_device_ids}\u003e : (tensor\u003c8x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "      sdy.return %1 : tensor\u003c32x64xf32\u003e\n",
            "    } : (tensor\u003c32x64xf32\u003e) -\u003e tensor\u003c32x64xf32\u003e\n",
            "    return %0 : tensor\u003c32x64xf32\u003e\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "f55laSGGdMP2"
      },
      "cell_type": "markdown",
      "source": [
        "We now:\n",
        "\n",
        "- Have a single op called `sdy.manual_computation` which holds:\n",
        "  - the `in_specs`\n",
        "  - the `out_specs`\n",
        "  - the body of the shard_map\n",
        "  - the inverse of the `auto` axes which we call `manual_axes`\n",
        "\n",
        "A lot easier to read!"
      ]
    },
    {
      "metadata": {
        "id": "ZgeWaxRgQZpI"
      },
      "cell_type": "markdown",
      "source": [
        "### `jax.experimental.custom_partitioning`\n",
        "\n",
        "With GSPMD, we define two routines, `propagate_user_sharding` and `infer_sharding_from_operands`, that may traverse jaxpr to return the sharding for the operands and results in order to use custom_partitioning. With Shardy, we provide `sharding_rule` corresponding to an Einsum like notation string to specify a sharding rule. Here is an example, where the routine that we use custom partition for implements a batch matrix multiplication."
      ]
    },
    {
      "metadata": {
        "id": "Q5Lkp8UaknwB"
      },
      "cell_type": "markdown",
      "source": [
        "We use a device array of (2M, M) to compute a matmul with the form of (...4N, 2N) x (...2N, 4N). Notice that instead of hard-coding the device array and the matrix shapes, we introduce two parameters, M and N, for specifying the shapes of the matrixes and the shapes of the device array so that we can change with these values to fit your purpose.\n",
        "\n",
        "We first perform the needed setup and define the `partition` routine as we would do with GSPMD."
      ]
    },
    {
      "metadata": {
        "id": "ar9YL9ftQZpI"
      },
      "cell_type": "code",
      "source": [
        "from functools import partial\n",
        "from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, BATCHING\n",
        "\n",
        "jax.config.update(\"jax_use_shardy_partitioner\", True)\n",
        "\n",
        "def partition(mesh, arg_shapes, result_shape):\n",
        "  arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)\n",
        "  result_sharding = result_shape.sharding\n",
        "  rank=len(arg_shapes[0].shape)\n",
        "\n",
        "  def lower_fn(x, y):\n",
        "    axis_name = arg_shardings[1].spec[rank-2][0]\n",
        "    i = jax.lax.axis_index(axis_name)\n",
        "    z = jax.lax.psum(jax.lax.dynamic_slice_in_dim(\n",
        "        jax.lax.dynamic_slice_in_dim(x, i * 0, N, axis=rank-2), i * N, N, axis=rank-1) @ y,\n",
        "        (axis_name))\n",
        "    return z\n",
        "\n",
        "  return mesh, lower_fn, (result_sharding), arg_shardings\n",
        "\n",
        "@partial(custom_partitioning)\n",
        "def f(x, y):\n",
        "  return jnp.matmul(x, y)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "TNpRzba6QZpI"
      },
      "cell_type": "markdown",
      "source": [
        "Then, we invoke the `def_partition` API. Note that instead of providing two callbacks for parameters `infer_sharding_from_operands` and `propagate_user_sharding` as we would do with GSPMD, we provide a `sharding_rule` parameter, which is an [Einsum-like notation string](https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/) similar to the subscripts in `jnp.einsum(\"...ij, ...jk-\u003e...ik\", x, y)`, if we would extend `jnp.einsum` to support the use of `...` for representing leading batching dimensions. We borrow the idea from [einops.rearrange](https://einops.rocks/api/rearrange/) string, to use a space separator between factors (to allow multiple letters factor names) and  to not specify the value for a factor that ever represents a whole tensor dimension. We also support rank polymorphism by allowing leading ... in each tensor dimension representation to represent any number leading dimensions. "
      ]
    },
    {
      "metadata": {
        "id": "-KgDCVS2QZpI"
      },
      "cell_type": "code",
      "source": [
        "f.def_partition(\n",
        "  infer_sharding_from_operands=None,\n",
        "  propagate_user_sharding=None,\n",
        "  partition=partition,\n",
        "  sharding_rule=\"... i j, ... j k -\u003e ... i k\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "LXLuB_oYQZpI"
      },
      "cell_type": "markdown",
      "source": [
        "Alternatively, we can also create an equivalent `SdyShardingRule` object for the `sharding_rule` parameter. See [Shardy document on sharding rule](https://github.com/openxla/shardy/blob/main/docs/propagation.md#operation-sharding-rule) for more details."
      ]
    },
    {
      "metadata": {
        "id": "3IazxrurQZpI"
      },
      "cell_type": "code",
      "source": [
        "f.def_partition(\n",
        "  infer_sharding_from_operands=None,\n",
        "  propagate_user_sharding=None,\n",
        "  partition=partition,\n",
        "  sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i', 'j'), (BATCHING, 'j', 'k')), result_mappings=((BATCHING, 'i', 'k'),)))"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "UGZOVTWfQZpI"
      },
      "cell_type": "markdown",
      "source": [
        "The `sharding_rule` parameter can also take a callback function generating either a string or an `SdyShardingRule` object."
      ]
    },
    {
      "metadata": {
        "id": "c9Ter9_OQZpI"
      },
      "cell_type": "code",
      "source": [
        "def sharding_rule_producer(mesh, arg_shapes, result_shape):\n",
        "  rank = len(arg_shapes[0].shape)\n",
        "  leading_axes = \"\"\n",
        "  for i in range(rank - 2):\n",
        "    leading_axes += f\" b{i}\"\n",
        "  return f\"{leading_axes} i j, {leading_axes} j k -\u003e {leading_axes} i k\"\n",
        "\n",
        "\n",
        "f.def_partition(\n",
        "  partition=partition,\n",
        "  sharding_rule=sharding_rule_producer)"
      ],
      "outputs": [],
      "execution_count": null
    },
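    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "To see what the callback produces, we can call `sharding_rule_producer` directly. This is a small sketch: the rank-4 shape is an arbitrary choice, and `None` is passed for `mesh` and `result_shape` only because this particular callback ignores them."
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "# Arbitrary rank-4 operand shape, so the rule gets two leading batch factors.\n",
        "arg = jax.ShapeDtypeStruct((2, 3, 16, 8), jnp.float32)\n",
        "rule = sharding_rule_producer(None, (arg, arg), None)\n",
        "\n",
        "# Collapse repeated spaces so the rule is easier to read.\n",
        "print(' '.join(rule.split()))\n",
        "# b0 b1 i j, b0 b1 j k -\u003e b0 b1 i k"
      ],
      "outputs": [],
      "execution_count": null
    },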
    {
      "metadata": {
        "id": "_8bQf5XEQZpI"
      },
      "cell_type": "markdown",
      "source": [
        "Lastly, we create a mesh, define the input matrixes x and y, run the jitted f, and compare the results producted by the unjitted and the jitted f."
      ]
    },
    {
      "metadata": {
        "id": "rTDWqaCSQZpI"
      },
      "cell_type": "code",
      "source": [
        "N = 4\n",
        "M = 2\n",
        "num_devices = 2 * M * M\n",
        "\n",
        "devices = np.array(list(jax.devices())[:num_devices])\n",
        "if devices.size \u003c num_devices:\n",
        "  raise ValueError(f'Requires {num_devices} devices')\n",
        "device_mesh = Mesh(devices.reshape((2 * M, M)), ('x', 'y'))\n",
        "\n",
        "sharding_x = NamedSharding(device_mesh, PartitionSpec(None, None, 'x'))\n",
        "sharding_y = NamedSharding(device_mesh, PartitionSpec(None, None, 'y'))\n",
        "jitted_f = jax.jit(f, in_shardings=(sharding_x, sharding_y), out_shardings=sharding_x)\n",
        "\n",
        "x = np.asarray(np.random.randint(0, 20, (2, 3, 4*N, 2*N)), dtype=np.float32)\n",
        "y = np.asarray(np.random.randint(0, 20, (2, 3, 2*N, 4*N)), dtype=np.float32)\n",
        "\n",
        "result = f(x, y)\n",
        "\n",
        "with device_mesh:\n",
        "  jitted_result = jitted_f(x, y)\n",
        "\n",
        "for i in range(num_devices):\n",
        "  j = (i // M) * N\n",
        "  assert((np.asarray(jitted_result.addressable_shards[i].data) == result[:,:,j:j+N,:]).all())"
      ],
      "outputs": [],
      "execution_count": null
    },
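    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "The slicing in `lower_fn` is easier to follow with the per-device shapes in hand. The sketch below asks each `NamedSharding` for the shape of a single shard via `shard_shape`, reusing `sharding_x`, `sharding_y`, `x` and `y` from the previous cell. With `N = 4` and `M = 2`, the values in the comments follow from dividing the sharded dimension by the size of its mesh axis. Each device therefore holds an (N, 2N) block of `x` and an (N, 4N) block of `y` per batch entry, which is why `lower_fn` slices `x` down to N columns before the local matmul."
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "# x is split along dim 2 over mesh axis 'x' (size 2 * M = 4).\n",
        "print(sharding_x.shard_shape(x.shape))  # (2, 3, 4, 8)\n",
        "# y is split along dim 2 over mesh axis 'y' (size M = 2).\n",
        "print(sharding_y.shard_shape(y.shape))  # (2, 3, 4, 16)"
      ],
      "outputs": [],
      "execution_count": null
    },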
    {
      "metadata": {
        "id": "Koj_x61PfIGi"
      },
      "cell_type": "markdown",
      "source": [
        "A factor usually corresponds to an elementwise or passthrough dimension, and has a default passthrough type. Keyword arguments can be used to specify other factor types. Currently, Shardy supports the following factor types: `reduction`, `need_replication`, and `permutation`. A typical example of `reduction` factor is the contracting dimensions in dot operations. In the matmul example above, if we partition factor `j` in the input, we have to insert an all-reduce for the output since this factor needs reduction. Here is how we can specify factor `j` as a `reduction` factor when using a string sharding rule:"
      ]
    },
    {
      "metadata": {
        "id": "Da7Ah54ChtW2"
      },
      "cell_type": "code",
      "source": [
        "f.def_partition(\n",
        "  infer_sharding_from_operands=None,\n",
        "  propagate_user_sharding=None,\n",
        "  partition=partition,\n",
        "  sharding_rule=\"... i j, ... j k -\u003e ... i k\", reduction_factors=('j',))"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "8lWtPhwqh8b8"
      },
      "cell_type": "markdown",
      "source": [
        "If `SdyShardingRule` object is used, we add `reduction_factors` as a keyword argument to the object construction:"
      ]
    },
    {
      "metadata": {
        "id": "APFIw19siOLC"
      },
      "cell_type": "code",
      "source": [
        "f.def_partition(\n",
        "  infer_sharding_from_operands=None,\n",
        "  propagate_user_sharding=None,\n",
        "  partition=partition,\n",
        "  sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i', 'j'), (BATCHING, 'j', 'k')), result_mappings=((BATCHING, 'i', 'k'),), reduction_factors=('j',)))"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "EZ81ku8WiXZS"
      },
      "cell_type": "markdown",
      "source": [
        "If a callback function generating a sharding rule string is used, we make the callback function also generate a dictionary containing the keyword argument and a value for the argument:"
      ]
    },
    {
      "metadata": {
        "id": "eA81kmYsgidX"
      },
      "cell_type": "code",
      "source": [
        "def sharding_rule_producer(mesh, arg_shapes, result_shape):\n",
        "  rank = len(arg_shapes[0].shape)\n",
        "  leading_axes = \"\"\n",
        "  for i in range(rank - 2):\n",
        "    leading_axes += f\" b{i}\"\n",
        "  return f\"{leading_axes} i j, {leading_axes} j k -\u003e {leading_axes} i k\", dict(reduction_factors=('j',))\n",
        "\n",
        "\n",
        "f.def_partition(\n",
        "  partition=partition,\n",
        "  sharding_rule=sharding_rule_producer)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "MrdZEzjiKmHH"
      },
      "cell_type": "markdown",
      "source": [
        "A factor may correspond to multiple dimension sizes, in this case, the smallest dimension size is used in the sharding rule. Here is a toy example to demonstrate this."
      ]
    },
    {
      "metadata": {
        "id": "08y2_vyOK3P9"
      },
      "cell_type": "code",
      "source": [
        "def partition(mesh, arg_shapes, result_shape):\n",
        "  arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)\n",
        "  result_sharding = result_shape.sharding\n",
        "\n",
        "  def lower_fn(x, y):\n",
        "    return x.reshape((4,)) + y\n",
        "\n",
        "  return mesh, lower_fn, (result_sharding), arg_shardings\n",
        "\n",
        "@partial(custom_partitioning)\n",
        "def f(x, y):\n",
        "  x = x.reshape((8,))\n",
        "  return x + y\n",
        "\n",
        "f.def_partition(\n",
        "  infer_sharding_from_operands=None,\n",
        "  propagate_user_sharding=None,\n",
        "  partition=partition,\n",
        "  sharding_rule=SdyShardingRule(operand_mappings=(('i', 'j'), ('i', )), result_mappings=(('i',),)))\n",
        "\n",
        "num_devices = 8\n",
        "devices = np.array(list(jax.devices())[:num_devices])\n",
        "if devices.size \u003c num_devices:\n",
        "  raise ValueError(f'Requires {num_devices} devices')\n",
        "device_mesh = Mesh(devices.reshape((2, 4)), ('x', 'y'))\n",
        "\n",
        "sharding_x = NamedSharding(device_mesh, PartitionSpec('x', None))\n",
        "sharding_y = NamedSharding(device_mesh, PartitionSpec('x', ))\n",
        "jitted_f = jax.jit(f, in_shardings=(sharding_x, sharding_y), out_shardings=sharding_y)\n",
        "\n",
        "x = np.arange(8).reshape(4, 2)\n",
        "y = np.arange(8)\n",
        "\n",
        "result = f(x, y)\n",
        "\n",
        "print(jitted_f.lower(x,y).as_text())\n",
        "with device_mesh:\n",
        "  jitted_result = jitted_f(x, y)\n",
        "\n",
        "for i in range(num_devices):\n",
        "  j = i // 4\n",
        "  assert((np.asarray(jitted_result.addressable_shards[i].data) == result[j*4:j*4+4]).all())"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @jit__unnamed_wrapped_function attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  sdy.mesh @mesh = \u003c[\"x\"=2, \"y\"=4]\u003e\n",
            "  func.func public @main(%arg0: tensor\u003c4x2xi32\u003e {sdy.sharding = #sdy.sharding\u003c@mesh, [{\"x\"}, {}]\u003e}, %arg1: tensor\u003c8xi32\u003e {sdy.sharding = #sdy.sharding\u003c@mesh, [{\"x\"}]\u003e}) -\u003e (tensor\u003c8xi32\u003e {jax.result_info = \"result\", sdy.sharding = #sdy.sharding\u003c@mesh, [{\"x\"}]\u003e}) {\n",
            "    %0 = stablehlo.custom_call @CustomSPMDPartitioning(%arg0, %arg1) {api_version = 2 : i32, backend_config = \"126939381919824\", sdy.sharding_rule = #sdy.op_sharding_rule\u003c([i, j], [i])-\u003e([i]) {i=4, j=2}, custom\u003e} : (tensor\u003c4x2xi32\u003e, tensor\u003c8xi32\u003e) -\u003e tensor\u003c8xi32\u003e\n",
            "    return %0 : tensor\u003c8xi32\u003e\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "BSCsZoKrpYTH"
      },
      "cell_type": "markdown",
      "source": [
        "Here is an example of a `need_replication` factor. The custom-partitioned function is a wrapper around `jnp.sort`. The dimension being sorted needs to be replicated in order to implement the operation, so we tell the partitioner that the factor corresponding to the sorted dimension is a `need_replication` factor."
      ]
    },
    {
      "metadata": {
        "id": "CTIrLg8Apk8D"
      },
      "cell_type": "code",
      "source": [
        "def partition(mesh, arg_shapes, result_shape):\n",
        "  arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)\n",
        "  result_sharding = result_shape.sharding\n",
        "\n",
        "  def lower_fn(x):\n",
        "    return jnp.sort(x, axis=0)\n",
        "\n",
        "  return mesh, lower_fn, (result_sharding), arg_shardings\n",
        "\n",
        "@partial(custom_partitioning)\n",
        "def f(x):\n",
        "  return jnp.sort(x, axis=0)\n",
        "\n",
        "f.def_partition(\n",
        "  infer_sharding_from_operands=None,\n",
        "  propagate_user_sharding=None,\n",
        "  partition=partition,\n",
        "  sharding_rule=\"i j -\u003e i j\", need_replication_factors=(\"i\",))\n",
        "\n",
        "devices = np.array(list(jax.devices())[:2])\n",
        "if devices.size \u003c 2:\n",
        "  raise ValueError('Requires 2 devices')\n",
        "device_mesh = Mesh(devices, ('x',))\n",
        "sharding = NamedSharding(device_mesh, PartitionSpec(None, 'x'))\n",
        "jitted_f = jax.jit(f, in_shardings=(sharding), out_shardings=sharding)\n",
        "\n",
        "input_data = np.asarray(np.random.randint(0, 10, (4, 8)), dtype=jnp.float32)\n",
        "result = f(input_data)\n",
        "\n",
        "with device_mesh:\n",
        "  jitted_result = jitted_f(input_data)\n",
        "\n",
        "assert((np.asarray(jitted_result.addressable_shards[0].data) == result[:,:4]).all())\n",
        "assert((np.asarray(jitted_result.addressable_shards[1].data) == result[:,4:]).all())"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "7_hpKD6oprA7"
      },
      "cell_type": "markdown",
      "source": [
        "Below is an example of a `permutation` factor. The custom-partitioned function is a wrapper around `jnp.flip` on dimension 0. If this dimension is partitioned, a collective permute operation is needed to exchange the data across devices in order to implement the reverse operation. As such, we tell the partitioner that the factor corresponding to dimension 0 is a `permutation` factor."
      ]
    },
    {
      "metadata": {
        "id": "KgavSM-ap21C"
      },
      "cell_type": "code",
      "source": [
        "def partition(mesh, arg_shapes, result_shape):\n",
        "  arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)\n",
        "  result_sharding = result_shape.sharding\n",
        "\n",
        "  def lower_fn(x):\n",
        "    return jnp.flip(x, axis=0)\n",
        "\n",
        "  return mesh, lower_fn, (result_sharding), arg_shardings\n",
        "\n",
        "@partial(custom_partitioning)\n",
        "def f(x):\n",
        "  return jnp.flip(x, axis=0)\n",
        "\n",
        "f.def_partition(\n",
        "  infer_sharding_from_operands=None,\n",
        "  propagate_user_sharding=None,\n",
        "  partition=partition,\n",
        "  sharding_rule=\"i j -\u003e i j\", permutation_factors=(\"i\",))\n",
        "\n",
        "devices = np.array(list(jax.devices())[:2])\n",
        "if devices.size \u003c 2:\n",
        "  raise ValueError('Requires 2 devices')\n",
        "device_mesh = Mesh(devices, ('x',))\n",
        "sharding = NamedSharding(device_mesh, PartitionSpec(None, 'x'))\n",
        "jitted_f = jax.jit(f, in_shardings=(sharding), out_shardings=sharding)\n",
        "\n",
        "input_data = np.asarray(np.random.randint(0, 10, (4, 8)), dtype=jnp.float32)\n",
        "result = f(input_data)\n",
        "\n",
        "with device_mesh:\n",
        "  jitted_result = jitted_f(input_data)\n",
        "\n",
        "assert((np.asarray(jitted_result.addressable_shards[0].data) == result[:,:4]).all())\n",
        "assert((np.asarray(jitted_result.addressable_shards[1].data) == result[:,4:]).all())"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "_UfLgKHie0ai"
      },
      "cell_type": "markdown",
      "source": [
        "### Auto partitioners\n",
        "\n",
        "In progress."
      ]
    },
    {
      "metadata": {
        "id": "JTR9tmEJnM6v"
      },
      "cell_type": "markdown",
      "source": [
        "### XLA_DUMP_TO\n",
        "\n",
        "When you specify `XLA_DUMP_TO`, you will see an additional `shardy/` directory containing various dumps of the StableHLO program. Many of them are currently only relevant to the Shardy team for debugging issues. The one to focus on when debugging is `sdy_module_after_sdy_export.mlir`, which is the module after propagation has finished on the StableHLO program.\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "id": "0kIyBosUoT_j"
      },
      "cell_type": "markdown",
      "source": [
        "## 2. Why Shardy?"
      ]
    },
    {
      "metadata": {
        "id": "3EdCWagVNiuF"
      },
      "cell_type": "markdown",
      "source": [
        "### Readability\n",
        "\n",
        "As seen above, it's much easier to read the shardings and `shard_map`s and understand how they match what is happening in the JAX code. In contrast, GSPMD propagation gives back HLO code, not the MLIR that both Shardy and `jax.jit.lower` return."
      ]
    },
    {
      "metadata": {
        "id": "si1rYISKoYNO"
      },
      "cell_type": "markdown",
      "source": [
        "### Interpretability\n",
        "\n",
        "We are planning on exposing a feature we call \"user priorities\" (not in JAX yet!). It allows you to attach a value telling Shardy how important a tensor's dimension sharding is relative to other constraints in the program.\n",
        "\n",
        "Higher priorities are defined as lower values (the lowest being 0; think of it as a p0 priority task).\n",
        "\n",
        "```python\n",
        "PartitionSpec(None, 'x', 'y', priorities=(None, 0, 1))\n",
        "```\n",
        "\n",
        "Here the sharding of dim 1 on `x` has a higher priority than the sharding of dim 2 on `y`, so dim 1 is propagated through the program first, followed by dim 2. Any potential sharding conflicts are therefore explicitly avoided by propagating `x` first.\n",
        "\n",
        "This can also help with debugging models, since it lets you break your sharding strategies down into separate rounds of propagation in Shardy. For example:\n",
        "\n",
        "* Priority 0: data parallelism\n",
        "* Priority 1: megatron\n",
        "* Priority 2: ZeRO sharding"
      ]
    },
    {
      "metadata": {
        "id": "cR9YOSVeUSyy"
      },
      "cell_type": "markdown",
      "source": [
        "## FAQs\n",
        "\n",
        "Below is a list of questions you may have on various JAX features and capabilities.\n",
        "\n",
        "### JAX Sharding types\n",
        "\n",
        "#### What about GSPMDSharding?\n",
        "\n",
        "`GSPMDSharding` is closely tied to the C++/protobuf representation inside the XLA compiler. As such, the type itself won't be supported.\n",
        "\n",
        "#### What about PositionalSharding?\n",
        "\n",
        "This won't be supported. Instead, use a `NamedSharding` with `device_ids`.\n",
        "\n",
        "#### What about PmapSharding?\n",
        "\n",
        "This won't be supported. Shardy is meant for `jax.jit`, not `jax.pmap`.\n",
        "\n",
        "### Propagation Questions\n",
        "\n",
        "This section covers questions about what you may see during propagation.\n",
        "\n",
        "#### What are split axes in Shardy, aka \"x\":(2)2?\n",
        "\n",
        "Refer to [Axis splitting and sub-axes](sharding_representation.md#axis-splitting-and-sub-axes).\n"
      ]
    },
    {
      "metadata": {
        "id": "LqNbAQkbRege"
      },
      "cell_type": "markdown",
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "gpuType": "V28",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
